{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vfch6fQlnG05"
   },
   "source": [
    "# Guide to module validators and fixers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RKnAlq0AnQCd"
   },
   "source": [
    "Opacus strives to enable private training of PyTorch models with minimal code changes on the user side. As you might have learnt by following the README and the introductory tutorials, Opacus does this by consuming your model, dataloader, and optimizer and returning wrapped counterparts that can perform privacy-related functions.\n",
    "\n",
    "## Why do I need a Module Validator?\n",
    "While most of the common models work with Opacus, not all of them do.\n",
    "1. Right off the bat, all non-trainable modules (such as `nn.ReLU`, `nn.Tanh`, etc.) and frozen modules (with parameters whose `requires_grad` is set to `False`) are compatible.\n",
    "2. Trainable modules must also be able to capture per-sample gradients in order to work in a DP setting. The layers supported by `GradSampleModule` and the implementations offered by `opacus.layers` have this property.\n",
    "3. Some modules, such as `BatchNorm`, are not DP friendly because a sample's normalized value depends on the other samples in the batch, and hence are incompatible with Opacus.\n",
    "4. Some other modules, such as `InstanceNorm`, are DP friendly except under certain configurations (e.g., when `track_running_stats` is `True`).\n",
    "\n",
    "It is unreasonable to expect you to remember all of this and handle it by hand, which is why Opacus provides a `ModuleValidator`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DW4DumLD0DWW"
   },
   "source": [
    "## `ModuleValidator` internals\n",
    "The `ModuleValidator` class has two primary class methods: `validate()` and `fix()`.\n",
    "\n",
    "As the name suggests, `validate()` checks a given module's compatibility with Opacus by ensuring that it is in training mode and that its trainable sub-modules are supported by `GradSampleModule` (i.e., per-sample gradients can be captured for them). More importantly, this method also checks the sub-modules and their configurations for compatibility issues (more on this in the next section).\n",
    "\n",
    "The `fix()` method attempts to make the module compatible with Opacus.\n",
    "\n",
    "In Opacus 0.x, the specific checks for each of the supported modules and the necessary replacements were done centrally in the validator with a series of `if` checks. Adding new validation checks and fixes required modifying the core Opacus code. In Opacus 1.0, this has been modularised by allowing you to register your own custom validator and fixer.\n",
    "\n",
    "In the rest of the tutorial, we will use `BatchNorm` as an example and show exactly how to do that."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RcrjckqXFNSG"
   },
   "source": [
    "### Registering validator\n",
    "We know that the `BatchNorm` module is not privacy friendly, so the validator should report an error for it, for example like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7pRLzbW2YgtX"
   },
   "outputs": [],
   "source": [
    "def validate_batchnorm(module):\n",
    "    return [Exception(\"BatchNorm is not supported\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dHCBtWUYLdOz"
   },
   "source": [
    "To register it, all you need to do is decorate the function as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Fgnf6V6Fm730"
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "from opacus.validators import register_module_validator\n",
    "\n",
    "@register_module_validator(\n",
    "    [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm]\n",
    ")\n",
    "def validate_batchnorm(module):\n",
    "    return [Exception(\"BatchNorm is not supported\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mm6x75oIMcU8"
   },
   "source": [
    "That's it! The above registers `validate_batchnorm()` for all of these modules: `[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm]`, and it will be called automatically along with the other validators when you call `privacy_engine.make_private()`.\n",
    "\n",
    "The decorator essentially adds your function to `ModuleValidator`'s registry so that it is cycled through during the validation phase.\n",
    "\n",
    "One small note: it is recommended that you make your validation exceptions as clear as possible. Opacus's validation for the above looks as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "l3wJw5GYOo-W"
   },
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "import torch.nn as nn\n",
    "from opacus.validators import register_module_validator\n",
    "from opacus.validators.errors import ShouldReplaceModuleError\n",
    "\n",
    "\n",
    "@register_module_validator(\n",
    "    [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm]\n",
    ")\n",
    "def validate(module) -> List[ShouldReplaceModuleError]:\n",
    "    # A validator returns a list of errors; an empty list means the module is fine.\n",
    "    return [\n",
    "        ShouldReplaceModuleError(\n",
    "            \"BatchNorm cannot support training with differential privacy. \"\n",
    "            \"The reason for it is that BatchNorm makes each sample's normalized value \"\n",
    "            \"depend on its peers in a batch, ie the same sample x will get normalized to \"\n",
    "            \"a different value depending on who else is in its batch. \"\n",
    "            \"Privacy-wise, this means that we would have to put a privacy mechanism there too. \"\n",
    "            \"While it can in principle be done, there are now multiple normalization layers that \"\n",
    "            \"do not have this issue: LayerNorm, InstanceNorm and their generalization GroupNorm \"\n",
    "            \"are all privacy-safe since they don't have this property. \"\n",
    "            \"We offer utilities to automatically replace BatchNorms to GroupNorms and we will \"\n",
    "            \"release pretrained models to help transition, such as GN-ResNet ie a ResNet using \"\n",
    "            \"GroupNorm, pretrained on ImageNet\"\n",
    "        )\n",
    "    ]  # quite a mouthful, but is super clear! ;)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5HTyhM6WO3DY"
   },
   "source": [
    "### Registering fixer\n",
    "\n",
    "Validating is good, but can we fix the issue when possible? The answer, of course, is yes, and the syntax is much the same as for a validator.\n",
    "\n",
    "`BatchNorm`, for example, can be replaced with `GroupNorm`, which is privacy friendly and often comes without a meaningful loss of performance. In Opacus, we do it as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7aXEZspDO0pD"
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "import torch.nn as nn\n",
    "from opacus.validators.utils import register_module_fixer\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "\n",
    "def _batchnorm_to_groupnorm(module) -> nn.GroupNorm:\n",
    "    \"\"\"\n",
    "    Converts a BatchNorm ``module`` to GroupNorm module.\n",
    "    This is a helper function.\n",
    "    Args:\n",
    "        module: BatchNorm module to be replaced\n",
    "    Returns:\n",
    "        GroupNorm module that can replace the BatchNorm module provided\n",
    "    Notes:\n",
    "        A default value of 32 is chosen for the number of groups based on the\n",
    "        paper *Group Normalization* https://arxiv.org/abs/1803.08494\n",
    "    \"\"\"\n",
    "    return nn.GroupNorm(\n",
    "        min(32, module.num_features), module.num_features, affine=module.affine\n",
    "    )\n",
    "\n",
    "\n",
    "@register_module_fixer(\n",
    "    [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm]\n",
    ")\n",
    "def fix(module) -> nn.GroupNorm:\n",
    "    logger.info(\n",
    "        \"The default batch_norm fixer replaces BatchNorm with GroupNorm.\"\n",
    "        \" The batch_norm validator module also offers implementations to replace\"\n",
    "        \" it with InstanceNorm or Identity. Please check them out and override the\"\n",
    "        \" fixer if those are more suitable for your needs.\"\n",
    "    )\n",
    "    return _batchnorm_to_groupnorm(module)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0-MDSXZDP6A-"
   },
   "source": [
    "Opacus does NOT automatically fix the module for you when you call `privacy_engine.make_private()`; it expects the module to be compliant before it is passed in. However, this can easily be done as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mA2kvLGSP5pf"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from opacus.validators import ModuleValidator\n",
    "\n",
    "model = torch.nn.Linear(2, 1)\n",
    "# Only fix the model if validation reports any incompatibilities.\n",
    "if not ModuleValidator.is_valid(model):\n",
    "    model = ModuleValidator.fix(model)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uQw1D8vbe0yz"
   },
   "source": [
    "If you want to use a custom fixer in place of the one provided, you can simply decorate your function using this same decorator. Note that the order of registration matters and the last function to be registered will be the one used.\n",
    "\n",
    "For example, to replace only `BatchNorm2d` with `InstanceNorm2d` (while keeping the default `GroupNorm` replacement for `BatchNorm1d` and `BatchNorm3d`), you can do:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AmaalNfzROWz"
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "from opacus.validators import register_module_fixer\n",
    "\n",
    "# Registered after the default fixer, so this one is used for BatchNorm2d.\n",
    "@register_module_fixer([nn.BatchNorm2d])\n",
    "def fix_batchnorm2d(module):\n",
    "    return nn.InstanceNorm2d(module.num_features)\n"
   ]
  },
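  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the registration mechanics and the last-registration-wins rule more concrete, here is a minimal pure-Python sketch of a decorator-based registry. It is only an illustration of the pattern, not Opacus's actual implementation: `FIXERS`, `register_fixer`, `fix`, and the `Fake*` classes are hypothetical names invented for this example, and the fixers return strings instead of real modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable, Dict, List\n",
    "\n",
    "# Hypothetical registry (an assumption, not Opacus code): layer type -> fixer.\n",
    "FIXERS: Dict[type, Callable] = {}\n",
    "\n",
    "\n",
    "def register_fixer(target_types: List[type]) -> Callable:\n",
    "    # Build a decorator that stores the decorated function under every type.\n",
    "    def decorator(fn: Callable) -> Callable:\n",
    "        for target_type in target_types:\n",
    "            # A later registration for the same type overwrites the earlier one.\n",
    "            FIXERS[target_type] = fn\n",
    "        return fn\n",
    "\n",
    "    return decorator\n",
    "\n",
    "\n",
    "# Placeholder classes standing in for real layers.\n",
    "class FakeBatchNorm1d:\n",
    "    pass\n",
    "\n",
    "\n",
    "class FakeBatchNorm2d:\n",
    "    pass\n",
    "\n",
    "\n",
    "@register_fixer([FakeBatchNorm1d, FakeBatchNorm2d])\n",
    "def default_fix(module):\n",
    "    return 'GroupNorm'\n",
    "\n",
    "\n",
    "@register_fixer([FakeBatchNorm2d])  # registered last, so it wins for the 2d type\n",
    "def custom_fix(module):\n",
    "    return 'InstanceNorm2d'\n",
    "\n",
    "\n",
    "def fix(module):\n",
    "    # Look up the fixer registered for the module's exact type and apply it.\n",
    "    return FIXERS[type(module)](module)\n",
    "\n",
    "\n",
    "print(fix(FakeBatchNorm1d()))  # default fixer still applies here\n",
    "print(fix(FakeBatchNorm2d()))  # custom fixer overrides the default\n"
   ]
  },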
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BPSNEozakx-i"
   },
   "source": [
    "We hope this tutorial was helpful! We welcome you to peek into the code under `opacus/validators/` for details. If you have any questions or comments, please don't hesitate to post them on our [forum](https://discuss.pytorch.org/c/opacus/29)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
